{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1201 13:45:05.513083 139878915139392 file_utils.py:39] PyTorch version 0.4.1 available.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0,'./data')\n",
    "\n",
    "import time\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim\n",
    "import torch.utils.data\n",
    "from torch import nn\n",
    "from experiment.utils import *\n",
    "#from models.decoder_with_attention import DecoderWithAttention\n",
    "from models.decoder_with_attention_final import DecoderWithAttention\n",
    "from transformers import (WEIGHTS_NAME, BertConfig,\n",
    "                                  BertForSequenceClassification, BertTokenizer,\n",
    "                                  )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['/device:CPU:0', '/device:XLA_CPU:0', '/device:XLA_GPU:0', '/device:GPU:0']\n"
     ]
    }
   ],
   "source": [
    "def _get_available_devices():\n",
    "    from tensorflow.python.client import device_lib\n",
    "    local_device_protos = device_lib.list_local_devices()\n",
    "    return [x.name for x in local_device_protos]\n",
    "\n",
    "# Print the device names TensorFlow reports for this kernel.\n",
    "print(_get_available_devices())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1201 13:45:06.667576 139878915139392 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
      "I1201 13:45:06.668815 139878915139392 configuration_utils.py:168] Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"finetuning_task\": \"sst-2\",\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"num_labels\": 1,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"pruned_heads\": {},\n",
      "  \"torchscript\": false,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_bfloat16\": false,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "I1201 13:45:07.610342 139878915139392 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
     ]
    }
   ],
   "source": [
    "from load_datasets_final import CaptionDataset\n",
    "from create_inputs_utils import (CaptionProcessor,)\n",
    "import create_inputs_utils_test\n",
    "\n",
    "#from data.load_datasets_final import *\n",
    "# from data.create_inputs_utils import *\n",
    "\n",
    "from experiment._train_one_epoch import *\n",
    "from experiment._validation_one_epoch import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data parameters\n",
    "data_folder = 'preprocessed_data'  # folder with data files saved by create_input_files.py\n",
    "data_name = 'preprocessed_coco'  # base name shared by data files\n",
    "\n",
    "# Model parameters\n",
    "emb_dim = 768  # dimension of word embeddings; 768 equals the hidden_size of the BERT config (1024 was used before)\n",
    "attention_dim = 384  # dimension of attention linear layers\n",
    "decoder_dim = 384  # dimension of decoder RNN\n",
    "dropout = 0.5  # dropout value handed to DecoderWithAttention\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")  # sets device for model and PyTorch tensors\n",
    "cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise a lot of computational overhead\n",
    "\n",
    "# Training parameters\n",
    "# start_epoch, epochs_since_improvement and best_bleu4 are overwritten from the checkpoint when one is loaded.\n",
    "start_epoch = 0\n",
    "epochs = 50  # number of epochs to train for (if early stopping is not triggered)\n",
    "epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU\n",
    "batch_size = 20\n",
    "# NOTE(review): the original note said only 1 worker works with h5py, but the value is 0 -- confirm which is intended.\n",
    "workers = 0  # num_workers passed to both DataLoaders\n",
    "best_bleu4 = 0.  # best BLEU-4 score seen so far\n",
    "# print_freq = 100  # print training/validation stats every __ batches\n",
    "checkpoint = 'ckpt/BERT_3.pth.tar'  # path to checkpoint, None if none; a non-None path is loaded with torch.load in a later cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the next line to build a fresh decoder instead of resuming from the checkpoint path.\n",
    "#checkpoint = None  # path to checkpoint, None if none\n",
    "# Vocabulary size handed to DecoderWithAttention; equals the vocab_size (30522) in the\n",
    "# bert-base-uncased config shown in this notebook's log output.\n",
    "BERT_VOCA_SIZE = 30522"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/torch/serialization.py:425: SourceChangeWarning: source code of class 'models.decoder_with_attention_final.DecoderWithAttention' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n",
      "/usr/local/lib/python3.6/dist-packages/torch/serialization.py:425: SourceChangeWarning: source code of class 'transformers.modeling_bert.BertModel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n"
     ]
    }
   ],
   "source": [
    "# Build (or restore) the decoder and its optimizer, then create the losses and data loaders.\n",
    "# The word-map JSON is not read here; the decoder's vocabulary size comes from BERT_VOCA_SIZE.\n",
    "\n",
    "# Initialize / load checkpoint\n",
    "if checkpoint is None:\n",
    "    decoder = DecoderWithAttention(attention_dim=attention_dim,\n",
    "                                   embed_dim=emb_dim,\n",
    "                                   decoder_dim=decoder_dim,\n",
    "                                   vocab_size=BERT_VOCA_SIZE,\n",
    "                                   dropout=dropout)\n",
    "\n",
    "    # Only parameters with requires_grad=True are handed to the optimizer.\n",
    "    decoder_optimizer = torch.optim.Adamax(params=filter(lambda p: p.requires_grad, decoder.parameters()))\n",
    "\n",
    "else:\n",
    "    # Load into a separate name so that `checkpoint` keeps holding the path. Rebinding it to the\n",
    "    # loaded dict made a second run of this cell call torch.load() on a dict and fail.\n",
    "    # NOTE(review): torch.load unpickles whole objects (decoder and optimizer) -- only load trusted files.\n",
    "    # With CUDA the default restore location is kept; on a CPU-only run storages are remapped to the CPU.\n",
    "    checkpoint_state = torch.load(checkpoint, map_location=None if device.type == 'cuda' else 'cpu')\n",
    "    start_epoch = checkpoint_state['epoch'] + 1  # resume with the epoch after the saved one\n",
    "    epochs_since_improvement = checkpoint_state['epochs_since_improvement']\n",
    "    best_bleu4 = checkpoint_state['bleu-4']\n",
    "    decoder = checkpoint_state['decoder']\n",
    "    decoder_optimizer = checkpoint_state['decoder_optimizer']\n",
    "\n",
    "\n",
    "# Move to GPU, if available\n",
    "decoder = decoder.to(device)\n",
    "\n",
    "# Loss functions\n",
    "criterion_ce = nn.CrossEntropyLoss().to(device)\n",
    "criterion_dis = nn.MultiLabelMarginLoss().to(device)\n",
    "\n",
    "# Custom dataloaders (shuffling has been removed for both splits).\n",
    "# Pinned host memory is requested only when a CUDA device is in use.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    CaptionDataset(data_folder, data_name, 'TRAIN'),\n",
    "    batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=(device.type == 'cuda'))\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    CaptionDataset(data_folder, data_name, 'VAL'),\n",
    "    batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=(device.type == 'cuda'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: [3][0/28322]\tBatch Time 0.227 (0.227)\tData Load Time 0.009 (0.009)\tLoss 40.4218 (40.4218)\tTop-5 Accuracy 43.515 (43.515)\n",
      "Epoch: [3][100/28322]\tBatch Time 0.160 (0.168)\tData Load Time 0.005 (0.012)\tLoss 4.8236 (12.5110)\tTop-5 Accuracy 80.083 (61.759)\n",
      "Epoch: [3][200/28322]\tBatch Time 0.159 (0.168)\tData Load Time 0.005 (0.012)\tLoss 4.1100 (10.9873)\tTop-5 Accuracy 75.877 (67.896)\n",
      "Epoch: [3][300/28322]\tBatch Time 0.179 (0.171)\tData Load Time 0.009 (0.015)\tLoss 13.7990 (10.1687)\tTop-5 Accuracy 69.478 (71.285)\n",
      "Epoch: [3][400/28322]\tBatch Time 0.151 (0.171)\tData Load Time 0.013 (0.014)\tLoss 5.5907 (9.5902)\tTop-5 Accuracy 83.190 (73.187)\n",
      "Epoch: [3][500/28322]\tBatch Time 0.144 (0.172)\tData Load Time 0.010 (0.015)\tLoss 12.9849 (9.3125)\tTop-5 Accuracy 71.053 (74.566)\n",
      "Epoch: [3][600/28322]\tBatch Time 0.150 (0.172)\tData Load Time 0.010 (0.015)\tLoss 8.4412 (8.9286)\tTop-5 Accuracy 80.508 (75.804)\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-7-440739b85566>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     13\u001b[0m           \u001b[0mcriterion_dis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcriterion_dis\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m           \u001b[0mdecoder_optimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdecoder_optimizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m           epoch=epoch)\n\u001b[0m\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/notebooks/project_cs470/cs470_project_version2/experiment/_train_one_epoch.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(train_loader, decoder, criterion_ce, criterion_dis, decoder_optimizer, epoch)\u001b[0m\n\u001b[1;32m     75\u001b[0m         \u001b[0;31m# Back prop.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m         \u001b[0mdecoder_optimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 77\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m         \u001b[0;31m# Clip gradients when they are getting too large\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m     91\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m         \"\"\"\n\u001b[0;32m---> 93\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     88\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     89\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Main loop: one pass of training and validation per epoch, followed by a checkpoint save.\n",
    "for epoch in range(start_epoch, epochs):\n",
    "\n",
    "    # Stop once 20 epochs have gone by without improvement.\n",
    "    if epochs_since_improvement == 20:\n",
    "        break\n",
    "    # At every non-zero multiple of 8 stalled epochs, call adjust_learning_rate with factor 0.8.\n",
    "    elif epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:\n",
    "        adjust_learning_rate(decoder_optimizer, 0.8)\n",
    "\n",
    "    # Training pass for this epoch.\n",
    "    train(\n",
    "        train_loader=train_loader,\n",
    "        decoder=decoder,\n",
    "        criterion_ce=criterion_ce,\n",
    "        criterion_dis=criterion_dis,\n",
    "        decoder_optimizer=decoder_optimizer,\n",
    "        epoch=epoch,\n",
    "    )\n",
    "\n",
    "    # Validation pass; its return value is treated as this epoch's BLEU-4.\n",
    "    recent_bleu4 = validate(\n",
    "        val_loader=val_loader,\n",
    "        decoder=decoder,\n",
    "        criterion_ce=criterion_ce,\n",
    "        criterion_dis=criterion_dis,\n",
    "    )\n",
    "\n",
    "    # Compare with the best score so far and update the stall counter.\n",
    "    is_best = recent_bleu4 > best_bleu4\n",
    "    best_bleu4 = max(recent_bleu4, best_bleu4)\n",
    "    if is_best:\n",
    "        epochs_since_improvement = 0\n",
    "    else:\n",
    "        epochs_since_improvement += 1\n",
    "        print(\"\\nEpochs since last best performance: %d\\n\" % (epochs_since_improvement,))\n",
    "\n",
    "    # Persist the state reached after this epoch.\n",
    "    save_checkpoint(data_name, epoch, epochs_since_improvement, decoder, decoder_optimizer, recent_bleu4, is_best)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
